{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning selection with inference in the full model\n",
    "\n",
    "This is the same example as considered in [Liu et al.](https://arxiv.org/abs/1801.09037), though we\n",
    "do not use the special analysis developed in that paper. Instead, we let the computer\n",
    "guide us in correcting for selection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.\n",
      "  from numpy.core.umath_tests import inner1d\n",
      "Using TensorFlow backend.\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:455: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:456: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:457: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:458: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:459: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "/Users/jonathantaylor/anaconda/envs/py36/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:462: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "R[write to console]: Loaded gbm 2.1.5\n",
      "\n",
      "R[write to console]: randomForest 4.6-14\n",
      "\n",
      "R[write to console]: Type rfNews() to see new features/changes/bug fixes.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import functools\n",
    "\n",
    "import numpy as np\n",
    "from scipy.stats import norm as ndist\n",
    "\n",
    "import statsmodels.api as sm\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "\n",
    "import regreg.api as rr\n",
    "\n",
    "from selectinf.tests.instance import gaussian_instance\n",
    "\n",
    "from selectinf.learning.utils import full_model_inference, pivot_plot\n",
    "from selectinf.learning.core import normal_sampler\n",
    "from selectinf.learning.Rfitters import logit_fit\n",
    "\n",
    "def simulate(n=200, p=100, s=10, signal=(0.5, 1), sigma=2, alpha=0.1, B=2000):\n",
    "    \"\"\"Simulate one Gaussian instance, select with the LASSO, run full-model inference.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n, p, s : int\n",
    "        Forwarded to `gaussian_instance` under the same names (presumably\n",
    "        sample size, number of features and sparsity -- confirm there).\n",
    "        `n` also sets the LASSO penalty, `lam = 4 * sqrt(n)`.\n",
    "    signal : tuple or float\n",
    "        Forwarded to `gaussian_instance` as `signal`.\n",
    "    sigma : float\n",
    "        Forwarded to `gaussian_instance` as `sigma`; `sigma**2` is also the\n",
    "        dispersion used for the covariance handed to `normal_sampler`.\n",
    "    alpha : float\n",
    "        Accepted for interface compatibility; not used by this function.\n",
    "    B : int\n",
    "        Forwarded to `full_model_inference` as `B`.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The return value of `full_model_inference`. The cell that calls this\n",
    "    function treats it as either `None` or a DataFrame with a 'pivot' column.\n",
    "    \"\"\"\n",
    "\n",
    "    # description of statistical problem\n",
    "\n",
    "    X, y, truth = gaussian_instance(n=n,\n",
    "                                    p=p,\n",
    "                                    s=s,\n",
    "                                    equicorrelated=False,\n",
    "                                    rho=0.5,\n",
    "                                    sigma=sigma,\n",
    "                                    signal=signal,\n",
    "                                    random_signs=True,\n",
    "                                    scale=False)[:3]\n",
    "\n",
    "    # Gram matrix, computed once and shared by the sampler covariance\n",
    "    # and the LASSO loss.\n",
    "    XTX = X.T.dot(X)\n",
    "\n",
    "    # The sampler uses the noise level passed to the simulation, not a\n",
    "    # residual-based estimate.\n",
    "    dispersion = sigma**2\n",
    "\n",
    "    S = X.T.dot(y)\n",
    "    covS = dispersion * XTX\n",
    "    sampler = normal_sampler(S, covS)\n",
    "\n",
    "    def meta_algorithm(XTX, lam, sampler):\n",
    "        \"\"\"Solve the LASSO for one draw from `sampler`; return the set of nonzero indices.\"\"\"\n",
    "\n",
    "        p = XTX.shape[0]\n",
    "\n",
    "        loss = rr.quadratic_loss((p,), Q=XTX)\n",
    "        pen = rr.l1norm(p, lagrange=lam)\n",
    "\n",
    "        # scale=0. presumably means the sampler adds no extra noise to S\n",
    "        # -- TODO confirm against normal_sampler.\n",
    "        noisy_S = sampler(scale=0.)\n",
    "        loss.quadratic = rr.identity_quadratic(0, 0, -noisy_S, 0)\n",
    "        problem = rr.simple_problem(loss, pen)\n",
    "        soln = problem.solve(max_its=100, tol=1.e-10)\n",
    "        return set(np.nonzero(soln)[0])\n",
    "\n",
    "    lam = 4. * np.sqrt(n)\n",
    "    selection_algorithm = functools.partial(meta_algorithm, XTX, lam)\n",
    "\n",
    "    # run selection algorithm\n",
    "\n",
    "    return full_model_inference(X,\n",
    "                                y,\n",
    "                                truth,\n",
    "                                selection_algorithm,\n",
    "                                sampler,\n",
    "                                success_params=(1, 1),\n",
    "                                B=B,\n",
    "                                fit_probability=logit_fit,\n",
    "                                fit_args={'df':20},\n",
    "                                how_many=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prob(select):  [0.762]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jonathantaylor/git-repos/selectinf/selectinf/distributions/discrete_family.py:86: RuntimeWarning: divide by zero encountered in log\n",
      "  self._lw = np.array([np.log(v) for v in xw[:,1]])\n"
     ]
    }
   ],
   "source": [
    "# Results accumulate in a CSV on disk: the file is read back before each\n",
    "# write, so rows from earlier iterations (and any file left by a previous\n",
    "# run) are included.\n",
    "csvfile = 'lasso_exact.csv'\n",
    "outbase = csvfile[:-4]\n",
    "\n",
    "for i in range(50):\n",
    "    df = simulate()\n",
    "\n",
    "    # NOTE(review): the first draw (i == 0) is simulated but never saved\n",
    "    # -- confirm this is intended.\n",
    "    if df is None or i == 0:\n",
    "        continue\n",
    "\n",
    "    try: # concatenate to disk\n",
    "        df = pd.concat([df, pd.read_csv(csvfile)])\n",
    "    except FileNotFoundError:\n",
    "        pass  # no earlier results on disk: the file starts with this draw\n",
    "    df.to_csv(csvfile, index=False)\n",
    "\n",
    "    # only plot once there is at least one pivot\n",
    "    if len(df['pivot']) > 0:\n",
    "        pivot_ax = pivot_plot(df, outbase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "all,-slideshow",
   "formats": "ipynb,Rmd"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
